{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using the Functional API to Build a CNN\n",
    "\n",
    "We now introduce the Functional API as an alternative way of building Keras models. In the previous MLP and CNN examples on MNIST, we used the Sequential API. The Sequential API is sufficient for simple models that have a single input and a single output. More advanced models are often graphs with multiple inputs, multiple outputs, or branches, which the Sequential API cannot express. In such cases, the Functional API is the method of choice.\n",
    "\n",
    "The Functional API builds on the concept of function composition:\n",
    "\n",
    "\\begin{equation*}\n",
    "y = (f_n \\circ f_{n-1} \\circ \\ldots \\circ f_1)(x)\n",
    "\\end{equation*}\n",
    "\n",
    "The output of one function becomes the input of the next, and so on. A function can also have multiple outputs that feed multiple downstream functions, and several functional blocks with separate inputs can be combined into one or more outputs.\n",
    "\n",
    "In the following example, we build a `3-Conv2D-1-Dense` model using the Functional API.\n",
    "\n",
    "As in the previous CNN on MNIST example, we begin with the initializations and the pre-processing of inputs and labels. Since this model is similar to the CNN we built with the Sequential API, we expect its performance to be about the same."
   ]
  },
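  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the composition formula concrete, here is a minimal plain-Python sketch. The helper `compose` and the toy functions `double`, `add_one` and `square` are hypothetical names introduced only for illustration; they are not part of Keras. Keras layers are chained in the same spirit, with each layer call playing the role of one function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "\n",
    "\n",
    "def compose(*functions):\n",
    "    # given (f_1, ..., f_n), return x -> f_n(...f_1(x)...)\n",
    "    return lambda x: reduce(lambda value, f: f(value), functions, x)\n",
    "\n",
    "\n",
    "def double(v):\n",
    "    return 2 * v\n",
    "\n",
    "\n",
    "def add_one(v):\n",
    "    return v + 1\n",
    "\n",
    "\n",
    "def square(v):\n",
    "    return v * v\n",
    "\n",
    "\n",
    "# f_1 = double, f_2 = add_one, f_3 = square\n",
    "pipeline = compose(double, add_one, square)\n",
    "print(pipeline(3))  # square(add_one(double(3))) = 49"
   ]
  },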
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow\n",
    "from tensorflow.keras.layers import Dense, Dropout, Input\n",
    "from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.datasets import mnist\n",
    "from tensorflow.keras.utils import to_categorical\n",
    "\n",
    "# load MNIST dataset\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# from sparse label to categorical\n",
    "num_labels = len(np.unique(y_train))\n",
    "y_train = to_categorical(y_train)\n",
    "y_test = to_categorical(y_test)\n",
    "\n",
    "# reshape and normalize input images\n",
    "image_size = x_train.shape[1]\n",
    "x_train = np.reshape(x_train,[-1, image_size, image_size, 1])\n",
    "x_test = np.reshape(x_test,[-1, image_size, image_size, 1])\n",
    "x_train = x_train.astype('float32') / 255\n",
    "x_test = x_test.astype('float32') / 255"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyper-parameters\n",
    "\n",
    "The hyper-parameters are similar to the ones used in the CNN on MNIST example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# network parameters\n",
    "input_shape = (image_size, image_size, 1)\n",
    "batch_size = 128\n",
    "kernel_size = 3\n",
    "filters = 64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Actual Model Building\n",
    "\n",
    "In the Functional API, a Keras layer is a callable: `y = Conv2D(...)(x)` creates the layer and applies it to the tensor `x`. We stack three `Conv2D` layers to form a simple backbone network. `MaxPooling2D` compresses the learned feature maps. With compressed feature maps, the following convolution sees a bigger receptive field and learns new representations over it.\n",
    "\n",
    "In the Functional API, the output of one layer becomes the input of the next layer. For example, if the first layer is `y2 = Conv2D(...)(y1)`, then the next layer is `y3 = Conv2D(...)(y2)`. To avoid a proliferation of variable names, we normally reuse the same name (e.g. `y`), as shown below.\n",
    "\n",
    "By contrast, when building a Sequential model, we use the model's `add()` method to stack layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use functional API to build cnn layers\n",
    "inputs = Input(shape=input_shape)\n",
    "y = Conv2D(filters=filters,\n",
    "           kernel_size=kernel_size,\n",
    "           activation='relu',\n",
    "           padding='same')(inputs)\n",
    "y = MaxPooling2D()(y)\n",
    "y = Conv2D(filters=filters,\n",
    "           kernel_size=kernel_size,\n",
    "           activation='relu',\n",
    "           padding='same')(y)\n",
    "y = MaxPooling2D()(y)\n",
    "y = Conv2D(filters=filters,\n",
    "           kernel_size=kernel_size,\n",
    "           activation='relu',\n",
    "           padding='same')(y)"
   ]
  },
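  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, the same backbone could be sketched with the Sequential API and its `add()` method. This is a minimal sketch, not the model used in the rest of this notebook: the name `seq_backbone` is hypothetical, the literal values mirror the hyper-parameters defined earlier, and passing an `Input` object to `add()` assumes a TensorFlow 2.x release that supports it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D\n",
    "from tensorflow.keras.models import Sequential\n",
    "\n",
    "# hypothetical stand-alone model; values mirror the hyper-parameters above\n",
    "seq_backbone = Sequential()\n",
    "seq_backbone.add(Input(shape=(28, 28, 1)))\n",
    "for block in range(3):\n",
    "    seq_backbone.add(Conv2D(filters=64,\n",
    "                            kernel_size=3,\n",
    "                            activation='relu',\n",
    "                            padding='same'))\n",
    "    # pool after the first two convolutions only, as in the functional version\n",
    "    if block < 2:\n",
    "        seq_backbone.add(MaxPooling2D())\n",
    "\n",
    "seq_backbone.summary()"
   ]
  },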
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Head is a Dense Layer\n",
    "\n",
    "Since we are doing 10-class classification (softmax regression), we need to `Flatten` the output of the 3-layer CNN into a vector so that the `Dense` layer can produce one output per class, modeling a 10-class categorical distribution. This is the same concept that we used in the MLP."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"functional_9\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_6 (InputLayer)         [(None, 28, 28, 1)]       0         \n",
      "_________________________________________________________________\n",
      "conv2d_15 (Conv2D)           (None, 28, 28, 64)        640       \n",
      "_________________________________________________________________\n",
      "max_pooling2d_10 (MaxPooling (None, 14, 14, 64)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_16 (Conv2D)           (None, 14, 14, 64)        36928     \n",
      "_________________________________________________________________\n",
      "max_pooling2d_11 (MaxPooling (None, 7, 7, 64)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_17 (Conv2D)           (None, 7, 7, 64)          36928     \n",
      "_________________________________________________________________\n",
      "flatten_4 (Flatten)          (None, 3136)              0         \n",
      "_________________________________________________________________\n",
      "dense_4 (Dense)              (None, 10)                31370     \n",
      "=================================================================\n",
      "Total params: 105,866\n",
      "Trainable params: 105,866\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# image to vector before connecting to dense layer\n",
    "y = Flatten()(y)\n",
    "# dropout regularization\n",
    "#y = Dropout(dropout)(y)\n",
    "outputs = Dense(num_labels, activation='softmax')(y)\n",
    "# build the model by supplying inputs/outputs\n",
    "model = Model(inputs=inputs, outputs=outputs)\n",
    "# network model in text\n",
    "model.summary()"
   ]
  },
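  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be checked by hand. A `Conv2D` layer has `kernel_size * kernel_size * in_channels * filters` weights plus one bias per filter, and a `Dense` layer has `inputs * units` weights plus one bias per unit. The helper names `conv2d_params` and `dense_params` below are hypothetical and serve only this check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d_params(kernel, in_channels, out_filters):\n",
    "    # weights of every kernel plus one bias per filter\n",
    "    return kernel * kernel * in_channels * out_filters + out_filters\n",
    "\n",
    "\n",
    "def dense_params(n_inputs, n_units):\n",
    "    # weight matrix plus one bias per unit\n",
    "    return n_inputs * n_units + n_units\n",
    "\n",
    "\n",
    "counts = [\n",
    "    conv2d_params(3, 1, 64),       # first Conv2D: 640\n",
    "    conv2d_params(3, 64, 64),      # second Conv2D: 36928\n",
    "    conv2d_params(3, 64, 64),      # third Conv2D: 36928\n",
    "    dense_params(7 * 7 * 64, 10),  # Dense on the 3136-dim vector: 31370\n",
    "]\n",
    "print(counts, sum(counts))  # total: 105866"
   ]
  },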
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Training and Validation\n",
    "\n",
    "The last step is similar to our MLP example. We compile the model, train it by calling `fit`, and finally evaluate it on the test dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "469/469 [==============================] - 46s 98ms/step - loss: 1.4132 - accuracy: 0.5942 - val_loss: 0.4621 - val_accuracy: 0.8452\n",
      "Epoch 2/20\n",
      "469/469 [==============================] - 44s 93ms/step - loss: 0.3605 - accuracy: 0.8910 - val_loss: 0.2893 - val_accuracy: 0.9168\n",
      "Epoch 3/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.2649 - accuracy: 0.9204 - val_loss: 0.2212 - val_accuracy: 0.9346\n",
      "Epoch 4/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.2087 - accuracy: 0.9378 - val_loss: 0.1718 - val_accuracy: 0.9486\n",
      "Epoch 5/20\n",
      "469/469 [==============================] - 44s 95ms/step - loss: 0.1693 - accuracy: 0.9498 - val_loss: 0.1573 - val_accuracy: 0.9511\n",
      "Epoch 6/20\n",
      "469/469 [==============================] - 44s 93ms/step - loss: 0.1407 - accuracy: 0.9583 - val_loss: 0.1250 - val_accuracy: 0.9625\n",
      "Epoch 7/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.1200 - accuracy: 0.9647 - val_loss: 0.1015 - val_accuracy: 0.9690\n",
      "Epoch 8/20\n",
      "469/469 [==============================] - 45s 96ms/step - loss: 0.1050 - accuracy: 0.9693 - val_loss: 0.0914 - val_accuracy: 0.9715\n",
      "Epoch 9/20\n",
      "469/469 [==============================] - 44s 95ms/step - loss: 0.0949 - accuracy: 0.9722 - val_loss: 0.0836 - val_accuracy: 0.9748\n",
      "Epoch 10/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0864 - accuracy: 0.9747 - val_loss: 0.0755 - val_accuracy: 0.9757\n",
      "Epoch 11/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0796 - accuracy: 0.9764 - val_loss: 0.0689 - val_accuracy: 0.9782\n",
      "Epoch 12/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0747 - accuracy: 0.9773 - val_loss: 0.0631 - val_accuracy: 0.9799\n",
      "Epoch 13/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0703 - accuracy: 0.9791 - val_loss: 0.0602 - val_accuracy: 0.9810\n",
      "Epoch 14/20\n",
      "469/469 [==============================] - 44s 95ms/step - loss: 0.0668 - accuracy: 0.9801 - val_loss: 0.0594 - val_accuracy: 0.9803\n",
      "Epoch 15/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0634 - accuracy: 0.9809 - val_loss: 0.0627 - val_accuracy: 0.9792\n",
      "Epoch 16/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0609 - accuracy: 0.9814 - val_loss: 0.0676 - val_accuracy: 0.9774\n",
      "Epoch 17/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0588 - accuracy: 0.9825 - val_loss: 0.0501 - val_accuracy: 0.9844\n",
      "Epoch 18/20\n",
      "469/469 [==============================] - 45s 95ms/step - loss: 0.0563 - accuracy: 0.9832 - val_loss: 0.0493 - val_accuracy: 0.9843\n",
      "Epoch 19/20\n",
      "469/469 [==============================] - 44s 94ms/step - loss: 0.0534 - accuracy: 0.9839 - val_loss: 0.0471 - val_accuracy: 0.9844\n",
      "Epoch 20/20\n",
      "469/469 [==============================] - 45s 95ms/step - loss: 0.0525 - accuracy: 0.9840 - val_loss: 0.0460 - val_accuracy: 0.9863\n",
      "79/79 [==============================] - 2s 22ms/step - loss: 0.0460 - accuracy: 0.9863\n",
      "\n",
      "Test accuracy: 98.6%\n"
     ]
    }
   ],
   "source": [
    "# categorical cross-entropy loss, SGD optimizer, accuracy metric\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer='sgd',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# train the model with input images and labels\n",
    "model.fit(x_train,\n",
    "          y_train,\n",
    "          validation_data=(x_test, y_test),\n",
    "          epochs=20,\n",
    "          batch_size=batch_size)\n",
    "\n",
    "# model accuracy on test dataset\n",
    "score = model.evaluate(x_test, y_test, batch_size=batch_size)\n",
    "print(\"\\nTest accuracy: %.1f%%\" % (100.0 * score[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
